# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Optional, Union

import transformers
from packaging import version

transformers_ge_4_52 = version.parse(transformers.__version__) >= version.parse("4.52")


class LLMModelArch:
    qwen = "qwen"
    llama = "llama"
    internlm2 = "internlm2"
    chatglm = "chatglm"
    deepseek_v2 = "deepseek_v2"
    baichuan = "baichuan"

    yuan = "yuan"
    codefuse = "codefuse"
    phi2 = "phi2"
    phi3 = "phi3"
    phi3_small = "phi3_small"
    telechat = "telechat"
    dbrx = "dbrx"


class MLLMModelArch:
    qwen_vl = "qwen_vl"
    qwen_audio = "qwen_audio"
    qwen2_vl = "qwen2_vl"
    qwen2_audio = "qwen2_audio"
    qwen2_5_omni = "qwen2_5_omni"

    cogvlm = "cogvlm"
    glm4v = "glm4v"
    glm4_1v = "glm4_1v"
    glm_edge_v = "glm_edge_v"

    llama3_1_omni = "llama3_1_omni"
    llama3_2_vision = "llama3_2_vision"

    llava_hf = "llava_hf"
    llava_next_video_hf = "llava_next_video_hf"

    llava_llama = "llava_llama"
    llava_mistral = "llava_mistral"

    xcomposer = "xcomposer"
    internvl = "internvl"
    minicpmv = "minicpmv"
    deepseek_vl = "deepseek_vl"
    deepseek_vl2 = "deepseek_vl2"
    deepseek_janus = "deepseek_janus"

    mplug_owl2 = "mplug_owl2"
    mplug_owl2_1 = "mplug_owl2_1"
    mplug_owl3 = "mplug_owl3"
    doc_owl2 = "doc_owl2"

    phi3_vision = "phi3_vision"
    phi4_multimodal = "phi4_multimodal"
    florence = "florence"
    idefics3 = "idefics3"

    got_ocr2 = "got_ocr2"

    ovis1_6 = "ovis1_6"
    molmo = "molmo"
    emu3_chat = "emu3_chat"
    megrez_omni = "megrez_omni"
    valley = "valley"
    gemma3n = "gemma3n"
    mistral_2503 = "mistral_2503"
    keye_vl = "keye_vl"


class ModelArch(LLMModelArch, MLLMModelArch):
    pass


@dataclass
class ModelKeys:

    arch_name: str = None

    embedding: str = None
    module_list: str = None
    lm_head: str = None

    q_proj: str = None
    k_proj: str = None
    v_proj: str = None
    o_proj: str = None
    attention: str = None

    mlp: str = None
    down_proj: str = None

    qkv_proj: str = None
    qk_proj: str = None
    qa_proj: str = None
    qb_proj: str = None
    kv_proj: str = None
    kva_proj: str = None
    kvb_proj: str = None


@dataclass
class MultiModelKeys(ModelKeys):
    language_model: Union[str, List[str]] = field(default_factory=list)
    aligner: Union[str, List[str]] = field(default_factory=list)
    vision_tower: Union[str, List[str]] = field(default_factory=list)
    generator: Union[str, List[str]] = field(default_factory=list)

    def __post_init__(self):
        for key in ["language_model", "aligner", "vision_tower", "generator"]:
            v = getattr(self, key)
            if isinstance(v, str):
                setattr(self, key, [v])
            if v is None:
                setattr(self, key, [])


MODEL_ARCH_MAPPING = {}


def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
    """
    model_type: The unique ID for the model type. Models with the same model_type share
        the same architectures, template, get_function, etc.
    """
    arch_name = model_arch.arch_name
    if not exist_ok and arch_name in MODEL_ARCH_MAPPING:
        raise ValueError(
            f"The `{arch_name}` has already been registered in the MODEL_ARCH_MAPPING."
        )

    MODEL_ARCH_MAPPING[arch_name] = model_arch


register_model_arch(
    ModelKeys(
        LLMModelArch.llama,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        o_proj="model.layers.{}.self_attn.o_proj",
        q_proj="model.layers.{}.self_attn.q_proj",
        k_proj="model.layers.{}.self_attn.k_proj",
        v_proj="model.layers.{}.self_attn.v_proj",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.internlm2,
        module_list="model.layers",
        mlp="model.layers.{}.feed_forward",
        down_proj="model.layers.{}.feed_forward.w2",
        attention="model.layers.{}.attention",
        o_proj="model.layers.{}.attention.wo",
        qkv_proj="model.layers.{}.attention.wqkv",
        embedding="model.tok_embeddings",
        lm_head="output",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.chatglm,
        module_list="transformer.encoder.layers",
        mlp="transformer.encoder.layers.{}.mlp",
        down_proj="transformer.encoder.layers.{}.mlp.dense_4h_to_h",
        attention="transformer.encoder.layers.{}.self_attention",
        o_proj="transformer.encoder.layers.{}.self_attention.dense",
        qkv_proj="transformer.encoder.layers.{}.self_attention.query_key_value",
        embedding="transformer.embedding",
        lm_head="transformer.output_layer",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.telechat,
        module_list="transformer.h",
        mlp="transformer.h.{}.mlp",
        down_proj="transformer.h.{}.mlp.down_proj",
        attention="transformer.h.{}.self_attention",
        o_proj="transformer.h.{}.self_attention.dense",
        q_proj="transformer.h.{}.self_attention.query",
        kv_proj="transformer.h.{}.self_attention.key_value",
        embedding="transformer.word_embeddings",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.baichuan,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        qkv_proj="model.layers.{}.self_attn.W_pack",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.yuan,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        qk_proj="model.layers.{}.self_attn.qk_proj",
        o_proj="model.layers.{}.self_attn.o_proj",
        q_proj="model.layers.{}.self_attn.q_proj",
        k_proj="model.layers.{}.self_attn.k_proj",
        v_proj="model.layers.{}.self_attn.v_proj",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.codefuse,
        module_list="gpt_neox.layers",
        mlp="gpt_neox.layers.{}.mlp",
        down_proj="gpt_neox.layers.{}.mlp.dense_4h_to_h",
        attention="gpt_neox.layers.{}.attention",
        o_proj="gpt_neox.layers.{}.attention.dense",
        qkv_proj="gpt_neox.layers.{}.attention.query_key_value",
        embedding="gpt_neox.embed_in",
        lm_head="gpt_neox.embed_out",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.phi2,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.fc2",
        attention="model.layers.{}.self_attn",
        o_proj="model.layers.{}.self_attn.dense",
        q_proj="model.layers.{}.self_attn.q_proj",
        k_proj="model.layers.{}.self_attn.k_proj",
        v_proj="model.layers.{}.self_attn.v_proj",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.qwen,
        module_list="transformer.h",
        mlp="transformer.h.{}.mlp",
        down_proj="transformer.h.{}.mlp.c_proj",
        attention="transformer.h.{}.attn",
        o_proj="transformer.h.{}.attn.c_proj",
        qkv_proj="transformer.h.{}.attn.c_attn",
        embedding="transformer.wte",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.dbrx,
        module_list="transformer.blocks",
        mlp="transformer.blocks.{}.ffn",
        attention="transformer.blocks.{}.norm_attn_norm.attn",
        o_proj="transformer.blocks.{}.norm_attn_norm.attn.out_proj",
        qkv_proj="transformer.blocks.{}.norm_attn_norm.attn.Wqkv",
        embedding="transformer.wte",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.phi3,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        o_proj="model.layers.{}.self_attn.o_proj",
        qkv_proj="model.layers.{}.self_attn.qkv_proj",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.phi3_small,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        o_proj="model.layers.{}.self_attn.dense",
        qkv_proj="model.layers.{}.self_attn.query_key_value",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

register_model_arch(
    ModelKeys(
        LLMModelArch.deepseek_v2,
        module_list="model.layers",
        mlp="model.layers.{}.mlp",
        down_proj="model.layers.{}.mlp.down_proj",
        attention="model.layers.{}.self_attn",
        o_proj="model.layers.{}.self_attn.o_proj",
        qa_proj="model.layers.{}.self_attn.q_a_proj",
        qb_proj="model.layers.{}.self_attn.q_b_proj",
        kva_proj="model.layers.{}.self_attn.kv_a_proj_with_mqa",
        kvb_proj="model.layers.{}.self_attn.kv_b_proj",
        embedding="model.embed_tokens",
        lm_head="lm_head",
    )
)

if transformers_ge_4_52:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llava_hf,
            language_model="model.language_model",
            aligner="model.multi_modal_projector",
            vision_tower="model.vision_tower",
        )
    )
else:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llava_hf,
            language_model="language_model",
            aligner="multi_modal_projector",
            vision_tower="vision_tower",
        )
    )

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_mistral,
        language_model="model.layers",
        aligner="model.mm_projector",
        vision_tower="model.vision_tower",
    )
)

if transformers_ge_4_52:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llava_next_video_hf,
            language_model="model.language_model",
            aligner=["model.multi_modal_projector"],
            vision_tower="model.vision_tower",
        )
    )
else:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llava_next_video_hf,
            language_model="language_model",
            aligner=["multi_modal_projector"],
            vision_tower="vision_tower",
        )
    )

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_llama,
        language_model="model.layers",
        aligner="model.mm_projector",
        vision_tower="model.vision_tower",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.xcomposer,
        language_model="model",
        aligner="vision_proj",
        vision_tower="vit",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.internvl,
        language_model="language_model",
        aligner="mlp1",
        vision_tower="vision_model",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.mplug_owl3,
        language_model="language_model",
        aligner="vision2text_model",
        vision_tower="vision_model",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.doc_owl2,
        language_model="model.layers",
        aligner=["model.vision2text", "model.hr_compressor"],
        vision_tower="model.vision_model",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_vl,
        language_model="language_model",
        aligner="aligner",
        vision_tower="vision_model",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_janus,
        language_model="language_model",
        vision_tower="vision_model",
        aligner="aligner",
        generator=["gen_vision_model", "gen_aligner", "gen_head", "gen_embed"],
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_vl2,
        language_model="language",
        vision_tower="vision",
        aligner="projector",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.minicpmv,
        language_model="llm",
        aligner="resampler",
        vision_tower="vpm",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.phi3_vision,
        language_model="model.layers",
        aligner="model.vision_embed_tokens.img_projection",
        vision_tower="model.vision_embed_tokens.img_processor",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.phi4_multimodal,
        language_model="model.layers",
        aligner=[
            "model.embed_tokens_extend.image_embed.img_projection",
            "model.embed_tokens_extend.audio_embed.audio_projection",
        ],
        vision_tower=[
            "model.embed_tokens_extend.image_embed.img_processor",
            "model.embed_tokens_extend.audio_embed.encoder",
        ],
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.cogvlm,
        language_model="model.layers",
        vision_tower="model.vision",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.florence,
        language_model="language_model",
        vision_tower="vision_tower",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen_vl,
        language_model="transformer.h",
        vision_tower="transformer.visual",
    )
)
# TODO: check lm_head, ALL
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen_audio,
        language_model="transformer.h",
        vision_tower="transformer.audio",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_audio,
        language_model="language_model",
        aligner="multi_modal_projector",
        vision_tower="audio_tower",
    )
)

if transformers_ge_4_52:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.qwen2_vl,
            language_model="model.language_model",
            aligner="model.visual.merger",
            vision_tower="model.visual",
        )
    )
else:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.qwen2_vl,
            language_model="model",
            aligner="visual.merger",
            vision_tower="visual",
        )
    )

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_5_omni,
        language_model="thinker.model",
        vision_tower=["thinker.audio_tower", "thinker.visual"],
        aligner=["thinker.audio_tower.proj", "thinker.visual.merger"],
        generator=["talker", "token2wav"],
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.glm4v,
        language_model="transformer.encoder",
        vision_tower="transformer.vision",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.glm4_1v,
        language_model="model.language_model",
        aligner="model.visual.merger",
        vision_tower="model.visual",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.idefics3,
        language_model="model.text_model",
        aligner="model.connector",
        vision_tower="model.vision_model",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llama3_1_omni,
        language_model="model.layers",
        aligner="model.speech_projector",
        vision_tower="model.speech_encoder",
        generator="speech_generator",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.got_ocr2,
        language_model="model.layers",
        aligner="model.mm_projector_vary",
        vision_tower="model.vision_tower_high",
    )
)

if transformers_ge_4_52:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llama3_2_vision,
            language_model="model.language_model",
            aligner="model.multi_modal_projector",
            vision_tower="model.vision_model",
        )
    )
else:
    register_model_arch(
        MultiModelKeys(
            MLLMModelArch.llama3_2_vision,
            language_model="language_model",
            aligner="multi_modal_projector",
            vision_tower="vision_model",
        )
    )

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.ovis1_6,
        language_model="llm",
        vision_tower="visual_tokenizer",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.molmo,
        language_model="model.transformer",
        vision_tower="model.vision_backbone",
        aligner="model.vision_backbone.image_projector",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.megrez_omni,
        language_model="llm",
        vision_tower=["vision", "audio"],
    )
)

register_model_arch(MultiModelKeys(MLLMModelArch.emu3_chat, language_model="model"))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.glm_edge_v,
        language_model="model.layers",
        vision_tower="model.vision",
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.valley,
        language_model="model",
        vision_tower=["model.vision_tower", "model.qwen2vl_vision_tower"],
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.gemma3n,
        language_model="model.language_model",
        aligner=["model.embed_vision", "model.embed_audio"],
        vision_tower=["model.vision_tower", "model.audio_tower"],
    )
)

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.keye_vl,
        language_model="model",
        aligner="mlp_AR",
        vision_tower="visual",
    )
)


def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
    return MODEL_ARCH_MAPPING.get(arch_name)
